import os
import difflib
import numpy as np
import tensorflow as tf
import scipy.io.wavfile as wav
from tqdm import tqdm
from scipy.fftpack import fft
# from python_speech_features import mfcc
from random import shuffle
from keras import backend as K
import random
from scipy.fftpack import dct

from Util.utilspacks.calEnergy_calZeroCrossingRate import endPointReTurnNp

# Per-file feature caches, keyed by wav path.  Populated by
# compute_fbank_result / compute_mfcc_result / compute_fbank2*, and cleared
# by get_data.adjustDataList between training chunks.
fbankDist = {}
mfccDist = {}
fbank2Dist = {}
# When True, compute_fbank2 reads and writes fbank2Dist; compute_fbank2_predict
# always caches regardless of this flag.
isCacheFlat = False


def data_hparams():
    """Return an object carrying the default data-pipeline hyper-parameters.

    Callers mutate the returned object's attributes (data_type, data_path,
    per-corpus switches, batch_size, ...) before handing it to get_data.
    """
    class params:
        def __init__(self):
            defaults = {
                'data_type': 'train',   # one of 'train' / 'dev' / 'test'
                'data_path': '',        # prefix prepended to every wav path
                'mmcs': None,
                'thchs30': None,
                'aishell': None,
                'prime': None,
                'stcmd': None,
                'ocean': None,
                'hai': None,
                'batch_size': None,
                'data_length': None,    # optional cap on dataset size
                'shuffle': True,
                'predict': False,
                'training': True,
            }
            for name, value in defaults.items():
                setattr(self, name, value)

    return params()


class get_data():
    """Dataset loader and batch iterator for the acoustic (AM) and language
    (LM) models.

    Reads index files ``data/<corpus>_<split>.txt`` (tab-separated fields:
    wav path, space-separated pinyin, hanzi transcript), builds token
    vocabularies, and yields padded training batches.
    """

    def __init__(self, args):
        """:param args: hyper-parameter object as returned by data_hparams()."""
        self.training = args.training
        self.data_type = args.data_type      # 'train' / 'dev' / 'test'
        self.data_path = args.data_path      # prefix prepended to wav paths
        # per-corpus on/off switches
        self.thchs30 = args.thchs30
        self.aishell = args.aishell
        self.prime = args.prime
        self.hai = args.hai
        self.ocean = args.ocean

        self.predict = args.predict

        self.countLength = None

        # window [starItem:endItem] applied by adjustDataList for chunked training
        self.starItem = 0
        self.endItem = None

        # full lists kept aside so adjustDataList can re-slice them
        self.wav_lst_global = None
        self.pny_lst_global = None
        self.han_lst_global = None

        self.pny_vocab = []
        self.am_vocab = []
        self.han_vocab = []
        self.lm_vocab = []
        self.mmcs = args.mmcs
        self.stcmd = args.stcmd
        self.data_length = args.data_length
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle
        if self.training:
            self.source_init()
        self.load_am_token()

    def load_am_token(self):
        """Load the fixed AM and LM token vocabularies from Util/lst/*.txt
        and append the '<PAD>' token to each."""
        print('make am vocab...')
        with open('Util/lst/am_tokens.txt', 'r', encoding='utf-8') as f1:
            self.am_vocab = [x.strip() for x in f1 if x.strip() != '']
        self.am_vocab.append('<PAD>')
        print('make lm vocab...')
        with open('Util/lst/lm_tokens.txt', 'r', encoding='utf-8') as f2:
            self.lm_vocab = [x.strip() for x in f2 if x.strip() != '']
        self.lm_vocab.append('<PAD>')

    def source_init(self):
        """Read the enabled corpus index files for the current split and fill
        wav_lst / pny_lst / han_lst (kept index-aligned), then build the LM
        vocabularies and optionally subsample to data_length entries."""
        print('get source list...')
        read_files = []
        if self.data_type == 'train':
            if self.thchs30:
                read_files.append('thchs_train.txt')
            if self.aishell:
                read_files.append('aishell_train.txt')
            if self.prime:
                read_files.append('prime.txt')
            if self.stcmd:
                read_files.append('stcmd.txt')
            if self.hai:
                read_files.append('hai_train.txt')
            if self.mmcs:
                read_files.append('mmcs_train.txt')
            if self.ocean:
                read_files.append('speechocean_train.txt')
        elif self.data_type == 'dev':
            if self.thchs30:
                read_files.append('thchs_dev.txt')
            if self.aishell:
                read_files.append('aishell_dev.txt')
            if self.hai:
                read_files.append('hai_test.txt')
            if self.mmcs:
                read_files.append('mmcs_dev.txt')
            if self.ocean:
                read_files.append('speechocean_dev.txt')
        elif self.data_type == 'test':
            if self.thchs30:
                read_files.append('thchs_test.txt')
            if self.aishell:
                read_files.append('aishell_test.txt')
            if self.mmcs:
                read_files.append('mmcs_dev.txt')
        print("============导入的文件===========")
        print(read_files)
        print("数量：", len(read_files))
        print("============导入的文件===========")
        self.wav_lst = []
        self.pny_lst = []
        self.han_lst = []
        for file in read_files:
            print('load ', file, ' data...')
            sub_file = 'data/' + file
            with open(sub_file, 'r', encoding='utf8') as f:
                data = f.readlines()
            for line in tqdm(data):
                wav_file, pny, han = line.split('\t')
                # skip entries whose transcript is empty/whitespace-only
                # (the original test `han == ' ' or han == ""` missed lines
                # that were just a newline)
                if han.strip() == '':
                    continue
                pny = pny.strip()
                han = han.strip('\n').replace(' ', '')
                self.wav_lst.append(wav_file)
                self.pny_lst.append(pny.split(' '))
                self.han_lst.append(han)
        self.load_am_token()
        print('make lm pinyin vocab...')
        self.pny_vocab = self.mk_lm_pny_vocab(self.pny_lst)
        print('make lm hanzi vocab...')
        self.han_vocab = self.mk_lm_han_vocab(self.han_lst)
        if self.data_length:
            # BUG FIX: sampling each list independently (the previous code)
            # destroyed the wav/pinyin/hanzi correspondence.  Sample ONE set
            # of indices and apply it to all three lists.
            idx = random.sample(range(len(self.wav_lst)), self.data_length)
            self.wav_lst = [self.wav_lst[i] for i in idx]
            self.pny_lst = [self.pny_lst[i] for i in idx]
            self.han_lst = [self.han_lst[i] for i in idx]

    def adjustDataList(self):
        """Select the current [starItem:endItem] slice of the global lists for
        iterative (chunked) training, and flush the feature caches."""
        # train on one slice at a time to speed up each pass
        self.wav_lst = self.wav_lst_global[self.starItem:self.endItem]
        self.pny_lst = self.pny_lst_global[self.starItem:self.endItem]
        self.han_lst = self.han_lst_global[self.starItem:self.endItem]

        # clear feature caches so memory does not accumulate across chunks
        global fbankDist, mfccDist, fbank2Dist
        fbankDist = {}
        mfccDist = {}
        fbank2Dist = {}

    def get_am_batch(self):
        """
        Infinite generator of acoustic-model batches.

        Yields (inputs, outputs): inputs holds padded spectrograms, padded
        label ids and their lengths; outputs is the dummy zero target required
        by the Keras CTC loss layer.  Samples whose audio is too short for CTC
        alignment are dropped, so batches may be smaller than batch_size.
        """
        shuffle_list = [i for i in range(len(self.wav_lst))]
        while 1:
            if self.shuffle:
                shuffle(shuffle_list)
            for i in range(len(self.wav_lst) // self.batch_size):
                wav_data_lst = []
                label_data_lst = []
                begin = i * self.batch_size
                end = begin + self.batch_size
                sub_list = shuffle_list[begin:end]
                for index in sub_list:
                    if self.predict:
                        fbank = compute_fbank2_predict(self.data_path + self.wav_lst[index])
                    else:
                        fbank = compute_fbank2(self.data_path + self.wav_lst[index])
                    # pad the time axis up to the next multiple of 8
                    # (the acoustic model downsamples time by 8)
                    pad_fbank = np.zeros((fbank.shape[0] // 8 * 8 + 8, fbank.shape[1]))
                    pad_fbank[:fbank.shape[0], :] = fbank
                    label = self.pny2id(self.pny_lst[index], self.am_vocab)
                    label_ctc_len = self.ctc_len(label)
                    if pad_fbank.shape[0] // 8 >= label_ctc_len:
                        wav_data_lst.append(pad_fbank)
                        label_data_lst.append(label)
                pad_wav_data, input_length = self.wav_padding(wav_data_lst)
                pad_label_data, label_length = self.label_padding(label_data_lst)
                inputs = {'the_inputs': pad_wav_data,
                          'the_labels': pad_label_data,
                          'input_length': input_length,
                          'label_length': label_length,
                          }
                outputs = {'ctc': np.zeros(pad_wav_data.shape[0], )}
                yield inputs, outputs

    def get_lm_batch(self):
        """
        Generator of language-model batches: (pinyin-id matrix, hanzi-id
        matrix), both zero-padded to the longer of the two max lengths.
        """
        batch_num = len(self.pny_lst) // self.batch_size
        for k in range(batch_num):
            begin = k * self.batch_size
            end = begin + self.batch_size
            input_batch = self.pny_lst[begin:end]
            label_batch = self.han_lst[begin:end]
            max_len_input = max([len(line) for line in input_batch])
            max_len_label = max([len(line) for line in label_batch])
            max_len = max(max_len_label, max_len_input)
            input_batch = np.array(
                [self.pny2id(line, self.pny_vocab) + [0] * (max_len - len(line)) for line in input_batch])
            label_batch = np.array(
                [self.han2id(line, self.han_vocab) + [0] * (max_len - len(line)) for line in label_batch])
            yield input_batch, label_batch

    def pny2id(self, line, vocab):
        """Map a list of pinyin tokens to vocab indices (O(V) per lookup)."""
        return [vocab.index(pny) for pny in line]

    def han2id(self, line, vocab):
        """Map a hanzi string (character by character) to vocab indices."""
        return [vocab.index(han) for han in line]

    def wav_padding(self, wav_data_lst):
        """Zero-pad a list of (frames, 200) spectrograms into one 4-D tensor.

        :return: (batch, max_frames, 200, 1) array and per-sample frame
                 counts already divided by 8 (the model's time downsampling).
        """
        wav_lens = [len(data) for data in wav_data_lst]
        wav_max_len = max(wav_lens)
        wav_lens = np.array([leng // 8 for leng in wav_lens])
        maxLen = 200  # feature width produced by compute_fbank2
        new_wav_data_lst = np.zeros((len(wav_data_lst), wav_max_len, maxLen, 1))
        for i in range(len(wav_data_lst)):
            new_wav_data_lst[i, :wav_data_lst[i].shape[0], :wav_data_lst[i].shape[1], 0] = wav_data_lst[i]
        return new_wav_data_lst, wav_lens

    def label_padding(self, label_data_lst):
        """Zero-pad label id lists to a (batch, max_label_len) matrix; also
        return the original lengths."""
        label_lens = np.array([len(label) for label in label_data_lst])
        max_label_len = max(label_lens)
        new_label_data_lst = np.zeros((len(label_data_lst), max_label_len))
        for i in range(len(label_data_lst)):
            new_label_data_lst[i][:len(label_data_lst[i])] = label_data_lst[i]
        return new_label_data_lst, label_lens

    def mk_am_vocab(self, data):
        """Build an AM vocabulary from token lists, appending the CTC blank '_'."""
        vocab = []
        for line in tqdm(data):
            for pny in line:
                if pny not in vocab:
                    vocab.append(pny)
        vocab.append('_')
        return vocab

    def mk_lm_pny_vocab(self, data):
        """Build the LM pinyin vocabulary; index 0 is reserved for '<PAD>'."""
        vocab = ['<PAD>']
        for line in tqdm(data):
            for pny in line:
                if pny not in vocab:
                    vocab.append(pny)
        return vocab

    def mk_lm_han_vocab(self, data):
        """Build the LM hanzi (character) vocabulary; index 0 is '<PAD>'."""
        vocab = ['<PAD>']
        for line in tqdm(data):
            line = ''.join(line.split(' '))
            for han in line:
                if han not in vocab:
                    vocab.append(han)
        return vocab

    def ctc_len(self, label):
        """Minimum number of output frames CTC needs for `label`: its length
        plus one extra blank for every pair of adjacent duplicate tokens."""
        add_len = 0
        label_len = len(label)
        for i in range(label_len - 1):
            if label[i] == label[i + 1]:
                add_len += 1
        return label_len + add_len

    def return_data_types(self):
        """dtypes for a tf.data pipeline: (features, and three int lengths/labels)."""
        return (tf.float32, tf.int32, tf.int32, tf.int32)

    def return_data_shape(self):
        """Tensor shape spec for the spectrogram input (time, features, 1)."""
        return (
            tf.TensorShape([None, None, 1])
        )


def compute_mfcc(file):
    """Read a wav file and return MFCC features (26 cepstra), keeping every
    3rd frame and transposing to (num_cepstra, num_frames).

    NOTE(review): the `mfcc` import (python_speech_features) is commented out
    at the top of this file, so calling this function raises NameError --
    dead code unless that import is restored.
    """
    fs, audio = wav.read(file)
    mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
    mfcc_feat = mfcc_feat[::3]
    mfcc_feat = np.transpose(mfcc_feat)
    return mfcc_feat


# Pre-emphasis filter
def H(file, u):
    """Apply pre-emphasis: y[0] = x[0], y[t] = x[t] - u * x[t-1]."""
    emphasized = file[1:] - u * file[:-1]
    return np.append(file[0], emphasized)


def compute_fbank_result(file):
    """Compute 40-dim log-Mel filter-bank (FBank) features for a wav file.

    Pipeline: pre-emphasis -> endpoint detection -> 25 ms frames with a
    10 ms hop -> Hann window -> 512-point power spectrum -> 40 triangular
    Mel filters -> 20*log10 (dB).  Results are memoized in fbankDist.

    :param file: path to a wav file readable by scipy.io.wavfile
    :return: np.ndarray of shape (num_frames, 40)
    """
    # memoized result from an earlier call on the same path
    if file in fbankDist.keys():
        return fbankDist[file]
    fs, wavsignal = wav.read(file)
    # 1. Pre-emphasis: boost high frequencies with y[t] = x[t] - u*x[t-1].
    u = 0.9375
    wavsignal = H(wavsignal, u)
    # Endpoint detection: trim silence (energy / zero-crossing based helper).
    wavsignal = endPointReTurnNp(wavsignal)
    # 2. Framing: 25 ms windows advanced by 10 ms.
    frame_size = 0.025
    frame_stride = 0.01
    frame_length, frame_step = frame_size * fs, frame_stride * fs  # Convert from seconds to samples
    signal_length = len(wavsignal)
    frame_length = int(round(frame_length))
    frame_step = int(round(frame_step))
    num_frames = int(np.ceil(
        float(np.abs(signal_length - frame_length)) / frame_step))  # Make sure that we have at least 1 frame

    # Zero-pad the tail so every frame has exactly frame_length samples
    # without truncating any samples from the original signal.
    pad_signal_length = num_frames * frame_step + frame_length
    z = np.zeros((pad_signal_length - signal_length))
    pad_signal = np.append(wavsignal, z)

    # Index matrix: row i holds the sample indices belonging to frame i.
    indices = np.tile(np.arange(0, frame_length), (num_frames, 1)) + np.tile(
        np.arange(0, num_frames * frame_step, frame_step), (frame_length, 1)).T

    frames = pad_signal[indices]
    # 3. Windowing (Hann window).
    frames *= np.hanning(frame_length)
    # 4. FFT and power spectrum.
    NFFT = 512
    mag_frames = np.absolute(np.fft.rfft(frames, NFFT))  # Magnitude of the FFT
    pow_frames = ((1.0 / NFFT) * (mag_frames ** 2))  # Power Spectrum
    # 5. Mel filter bank: 40 triangular filters equally spaced on the Mel scale.
    nfilt = 40

    low_freq_mel = 0
    high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700))  # Convert Hz to Mel

    mel_points = np.linspace(low_freq_mel, high_freq_mel, nfilt + 2)  # Equally spaced in Mel scale
    hz_points = (700 * (10 ** (mel_points / 2595) - 1))  # Convert Mel to Hz

    fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1))))
    # FFT-bin index of each filter edge (note: `bin` shadows the builtin).
    bin = np.floor((NFFT + 1) * hz_points / fs)

    for m in range(1, nfilt + 1):
        f_m_minus = int(bin[m - 1])  # left
        f_m = int(bin[m])  # center
        f_m_plus = int(bin[m + 1])  # right

        # rising edge of triangle m
        for k in range(f_m_minus, f_m):
            fbank[m - 1, k] = (k - bin[m - 1]) / (bin[m] - bin[m - 1])
        # falling edge of triangle m
        for k in range(f_m, f_m_plus):
            fbank[m - 1, k] = (bin[m + 1] - k) / (bin[m + 1] - bin[m])
    filter_banks = np.dot(pow_frames, fbank.T)
    filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks)  # Numerical Stability
    filter_banks = 20 * np.log10(filter_banks)  # dB
    # (The DCT / liftering / mean-normalization steps that would turn FBank
    # into MFCC are applied in compute_mfcc_result, not here.)

    # cache and return the FBank features
    fbankDist[file] = filter_banks
    return filter_banks


def compute_mfcc_result(file):
    """Compute 12-coefficient MFCC features for a wav file.

    Same front end as compute_fbank_result (pre-emphasis, endpoint detection,
    framing, Hann window, power spectrum, 40 Mel filters), then applies DCT,
    sinusoidal liftering and mean normalization.  Memoized in mfccDist.

    :param file: path to a wav file readable by scipy.io.wavfile
    :return: np.ndarray of shape (num_frames, 12)
    """
    # memoized result from an earlier call on the same path
    if file in mfccDist.keys():
        return mfccDist[file]
    fs, wavsignal = wav.read(file)
    # 1. Pre-emphasis: boost high frequencies with y[t] = x[t] - u*x[t-1].
    u = 0.9375
    wavsignal = H(wavsignal, u)
    # Endpoint detection: trim silence (energy / zero-crossing based helper).
    wavsignal = endPointReTurnNp(wavsignal)
    # 2. Framing: 25 ms windows advanced by 10 ms.
    frame_size = 0.025
    frame_stride = 0.01
    frame_length, frame_step = frame_size * fs, frame_stride * fs  # Convert from seconds to samples
    signal_length = len(wavsignal)
    frame_length = int(round(frame_length))
    frame_step = int(round(frame_step))
    num_frames = int(np.ceil(
        float(np.abs(signal_length - frame_length)) / frame_step))  # Make sure that we have at least 1 frame

    # Zero-pad the tail so every frame has exactly frame_length samples
    # without truncating any samples from the original signal.
    pad_signal_length = num_frames * frame_step + frame_length
    z = np.zeros((pad_signal_length - signal_length))
    pad_signal = np.append(wavsignal, z)

    # Index matrix: row i holds the sample indices belonging to frame i.
    indices = np.tile(np.arange(0, frame_length), (num_frames, 1)) + np.tile(
        np.arange(0, num_frames * frame_step, frame_step), (frame_length, 1)).T

    frames = pad_signal[indices]
    # 3. Windowing (Hann window).
    frames *= np.hanning(frame_length)
    # 4. FFT and power spectrum.
    NFFT = 512
    mag_frames = np.absolute(np.fft.rfft(frames, NFFT))  # Magnitude of the FFT
    pow_frames = ((1.0 / NFFT) * (mag_frames ** 2))  # Power Spectrum
    # 5. Mel filter bank: 40 triangular filters equally spaced on the Mel scale.
    nfilt = 40

    low_freq_mel = 0
    high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700))  # Convert Hz to Mel

    mel_points = np.linspace(low_freq_mel, high_freq_mel, nfilt + 2)  # Equally spaced in Mel scale
    hz_points = (700 * (10 ** (mel_points / 2595) - 1))  # Convert Mel to Hz

    fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1))))
    # FFT-bin index of each filter edge (note: `bin` shadows the builtin).
    bin = np.floor((NFFT + 1) * hz_points / fs)

    for m in range(1, nfilt + 1):
        f_m_minus = int(bin[m - 1])  # left
        f_m = int(bin[m])  # center
        f_m_plus = int(bin[m + 1])  # right

        # rising edge of triangle m
        for k in range(f_m_minus, f_m):
            fbank[m - 1, k] = (k - bin[m - 1]) / (bin[m] - bin[m - 1])
        # falling edge of triangle m
        for k in range(f_m, f_m_plus):
            fbank[m - 1, k] = (bin[m + 1] - k) / (bin[m + 1] - bin[m])
    filter_banks = np.dot(pow_frames, fbank.T)
    filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks)  # Numerical Stability
    filter_banks = 20 * np.log10(filter_banks)  # dB
    # 6. DCT to decorrelate the filter-bank coefficients (MFCC proper).
    num_ceps = 12
    cep_lifter = 22
    mfcc = dct(filter_banks, type=2, axis=1, norm='ortho')[:, 1: (num_ceps + 1)]  # Keep 2-13
    # Sinusoidal liftering: de-emphasize higher-order MFCCs, which has been
    # reported to help recognition on noisy speech.
    (nframes, ncoeff) = mfcc.shape
    n = np.arange(ncoeff)
    lift = 1 + (cep_lifter / 2) * np.sin(np.pi * n / cep_lifter)
    mfcc *= lift  # *
    # Mean normalization (filter_banks is normalized too but then discarded).
    filter_banks -= (np.mean(filter_banks, axis=0) + 1e-8)
    mfcc -= (np.mean(mfcc, axis=0) + 1e-8)

    # cache and return the MFCC features (not FBank, despite the twin function)
    mfccDist[file] = mfcc
    return mfcc


# Precomputed 400-point Hamming window (0.54 - 0.46*cos form) shared by the
# compute_fbank2* functions below.  400 samples = 25 ms at fs = 16 kHz --
# assumes all wavs are 16 kHz; TODO confirm upstream.
x = np.linspace(0, 400 - 1, 400, dtype=np.int64)
w = 0.54 - 0.46 * np.cos(2 * np.pi * (x) / (400 - 1))  # Hamming window


def compute_fbank2_predict(file):
    """Compute log-spectrogram features for prediction.

    Pipeline: endpoint detection -> 25 ms Hamming-windowed frames with a
    10 ms hop -> |FFT| -> log(1 + x).  Unlike compute_fbank2, pre-emphasis is
    skipped and the result is always cached in fbank2Dist.

    :param file: path to a .wav file
    :return: np.ndarray of shape (num_frames, window_length // 2)
    :raises ValueError: if `file` does not end in .wav
    """
    # guard clause (the original used `if ...: pass / else: raise`)
    if not str(file).endswith(".wav"):
        raise ValueError("只适用于wav文件")
    if file in fbank2Dist:
        return fbank2Dist[file]
    fs, wavsignal = wav.read(file)
    # Endpoint detection: trim leading/trailing silence.
    wavsignal = endPointReTurnNp(wavsignal)
    time_window = 25  # frame length in ms
    window_length = fs // 1000 * time_window  # samples per frame (400 at 16 kHz)
    wav_arr = np.array(wavsignal)
    # number of 10 ms hops that fit in the signal
    range0_end = int(len(wavsignal) / fs * 1000 - time_window) // 10 + 1
    # BUG FIX: `np.float` was removed in NumPy >= 1.24; use np.float64.
    data_input = np.zeros((range0_end, int(window_length // 2)), dtype=np.float64)
    for i in range(0, range0_end):
        # NOTE(review): the fixed 160-sample hop, 400-sample frame and the
        # module-level window `w` all assume fs == 16000 -- confirm upstream.
        p_start = i * 160
        p_end = p_start + 400
        data_line = wav_arr[p_start:p_end] * w  # apply the Hamming window
        data_line = np.abs(fft(data_line))
        # keep only the first half of the symmetric magnitude spectrum
        data_input[i] = data_line[0: window_length // 2]
    data_input = np.log(data_input + 1)
    fbank2Dist[file] = data_input
    return data_input


def compute_fbank2(file):
    """Compute log-spectrogram features for training.

    Pipeline: 25 ms Hamming-windowed frames with a 10 ms hop -> |FFT| ->
    log(1 + x).  No pre-emphasis or endpoint detection (compare
    compute_fbank2_predict).  Caching via fbank2Dist only when the module
    flag isCacheFlat is True.

    :param file: path to a .wav file
    :return: np.ndarray of shape (num_frames, window_length // 2)
    :raises ValueError: if `file` does not end in .wav
    """
    # guard clause (the original used `if ...: pass / else: raise`)
    if not str(file).endswith(".wav"):
        raise ValueError("只适用于wav文件")
    if isCacheFlat and file in fbank2Dist:
        return fbank2Dist[file]
    fs, wavsignal = wav.read(file)
    time_window = 25  # frame length in ms
    window_length = fs // 1000 * time_window  # samples per frame (400 at 16 kHz)
    wav_arr = np.array(wavsignal)
    # number of 10 ms hops that fit in the signal
    range0_end = int(len(wavsignal) / fs * 1000 - time_window) // 10 + 1
    # BUG FIX: `np.float` was removed in NumPy >= 1.24; use np.float64.
    data_input = np.zeros((range0_end, int(window_length // 2)), dtype=np.float64)
    for i in range(0, range0_end):
        # NOTE(review): the fixed 160-sample hop, 400-sample frame and the
        # module-level window `w` all assume fs == 16000 -- confirm upstream.
        p_start = i * 160
        p_end = p_start + 400
        data_line = wav_arr[p_start:p_end] * w  # apply the Hamming window
        data_line = np.abs(fft(data_line))
        # keep only the first half of the symmetric magnitude spectrum
        data_input[i] = data_line[0: window_length // 2]
    data_input = np.log(data_input + 1)
    if isCacheFlat:
        fbank2Dist[file] = data_input
    return data_input


# from python_speech_features import delta


# def GetMfccFeature(file):
#     """
#     获取mfcc特征
#     :param wavsignal:
#     :param fs:
#     :return:
#     """
#     if file in mfccDist.keys():
#         return mfccDist[file]
#     fs, wavsignal = wav.read(file)
#     # 获取输入特征
#     feat_mfcc = mfcc(wavsignal, fs)
#     feat_mfcc_d = delta(feat_mfcc, 2)
#     feat_mfcc_dd = delta(feat_mfcc_d, 2)
#     # 返回值分别是mfcc特征向量的矩阵及其一阶差分和二阶差分矩阵
#     wav_feature = np.column_stack((feat_mfcc, feat_mfcc_d, feat_mfcc_dd))
#     mfccDist[file] = wav_feature
#     return wav_feature


# word error rate------------------------------------
def GetEditDistance(str1, str2):
    """Approximate Levenshtein cost between two sequences using difflib.

    Each 'replace' opcode contributes the larger side of the changed span;
    'insert'/'delete' contribute the inserted/deleted length; 'equal' is free.
    """
    def _cost(op):
        tag, i1, i2, j1, j2 = op
        if tag == 'replace':
            return max(i2 - i1, j2 - j1)
        if tag == 'insert':
            return j2 - j1
        if tag == 'delete':
            return i2 - i1
        return 0  # 'equal'

    matcher = difflib.SequenceMatcher(None, str1, str2)
    return sum(_cost(op) for op in matcher.get_opcodes())


# 定义解码器------------------------------------
def decode_ctc(num_result, num2word):
    """Greedy CTC decode of a (1, time, vocab) probability tensor.

    :param num_result: model output probabilities, shape (1, T, V)
    :param num2word: index -> token lookup (list or dict)
    :return: (best_path ids as a 1-D array, decoded token list)
    """
    probs = num_result[:, :, :]
    seq_len = np.zeros((1), dtype=np.int32)
    seq_len[0] = probs.shape[1]
    decoded = K.ctc_decode(probs, seq_len, greedy=True, beam_width=10, top_paths=1)
    best_path = K.get_value(decoded[0][0])[0]
    text = [num2word[idx] for idx in best_path]
    return best_path, text




if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Smoke test: extract features from one wav file and plot them.
    file = "E:/MAGICDATA_Mandarin_Chinese_Speech/MAGICDATA_Mandarin_Chinese_Speech/wav/train/16_3807/16_3807_20170817150902.wav"
    # BUG FIX: the original called `compute_fbank`, which is not defined
    # anywhere in this file; compute_fbank2 is the spectrogram extractor
    # used by the training pipeline.
    features = compute_fbank2(file)
    print(features.shape)
    print(features)
    plt.plot(features)
    plt.show()
